{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_tabnet.multitask import TabNetMultiTaskClassifier\n",
    "\n",
    "import torch\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "\n",
    "\n",
    "import os\n",
    "import wget\n",
    "from pathlib import Path\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "%load_ext autoreload\n",
    "\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download census-income dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remote location of the raw file and the local path it is cached under (<cwd>/data/census-income.csv).\n",
    "dataset_name = 'census-income'\n",
    "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\"\n",
    "out = Path(os.getcwd()) / 'data' / f\"{dataset_name}.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure the target folder exists, then fetch the file only when no local copy is present.\n",
    "out.parent.mkdir(parents=True, exist_ok=True)\n",
    "if not out.exists():\n",
    "    print(\"Downloading file...\")\n",
    "    wget.download(url, out.as_posix())\n",
    "else:\n",
    "    print(\"File already exists.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data and split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): read_csv is called with default header handling, so the first row of the\n",
    "# file is used as column names; the target name below presumably comes from that row - confirm.\n",
    "train = pd.read_csv(out)\n",
    "target = ' <=50K'\n",
    "\n",
    "# Draw a random 80/10/10 train/valid/test assignment unless the frame already carries one.\n",
    "if \"Set\" not in train.columns:\n",
    "    train[\"Set\"] = np.random.choice(\n",
    "        [\"train\", \"valid\", \"test\"],\n",
    "        size=(train.shape[0],),\n",
    "        p=[.8, .1, .1],\n",
    "    )\n",
    "\n",
    "# Row labels belonging to each split.\n",
    "train_indices = train.index[train[\"Set\"] == \"train\"]\n",
    "valid_indices = train.index[train[\"Set\"] == \"valid\"]\n",
    "test_indices = train.index[train[\"Set\"] == \"test\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple preprocessing\n",
    "\n",
    "Label encode categorical features and fill empty cells."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label-encode every column that looks categorical (object dtype or fewer than 200\n",
    "# distinct values) and impute the remaining columns with their training-split mean.\n",
    "nunique = train.nunique()\n",
    "types = train.dtypes\n",
    "\n",
    "categorical_columns = []\n",
    "categorical_dims = {}\n",
    "for col in train.columns:\n",
    "    if types[col] == 'object' or nunique[col] < 200:\n",
    "        print(col, nunique[col])\n",
    "        l_enc = LabelEncoder()\n",
    "        # Missing values become their own category before encoding.\n",
    "        train[col] = train[col].fillna(\"VV_likely\")\n",
    "        train[col] = l_enc.fit_transform(train[col].values)\n",
    "        categorical_columns.append(col)\n",
    "        categorical_dims[col] = len(l_enc.classes_)\n",
    "    else:\n",
    "        # Fill this column only: a frame-wide fillna would write this column's mean\n",
    "        # into the missing cells of every other column as well.\n",
    "        train[col] = train[col].fillna(train.loc[train_indices, col].mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define categorical features for categorical embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Columns that are not model inputs: the split marker and the label itself.\n",
    "unused_feat = ['Set']\n",
    "excluded_columns = unused_feat + [target]\n",
    "features = [col for col in train.columns if col not in excluded_columns]\n",
    "\n",
    "# Position and cardinality of every categorical input, in feature order.\n",
    "cat_idxs = []\n",
    "cat_dims = []\n",
    "for feat_pos, feat_name in enumerate(features):\n",
    "    if feat_name in categorical_columns:\n",
    "        cat_idxs.append(feat_pos)\n",
    "        cat_dims.append(categorical_dims[feat_name])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Network parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model set-up: 1-d embeddings for the categorical inputs, Adam with lr=2e-2 and a\n",
    "# StepLR learning-rate schedule (step_size=50, gamma=0.9).\n",
    "clf = TabNetMultiTaskClassifier(\n",
    "    cat_idxs=cat_idxs,\n",
    "    cat_dims=cat_dims,\n",
    "    cat_emb_dim=1,\n",
    "    optimizer_fn=torch.optim.Adam,\n",
    "    optimizer_params={\"lr\": 2e-2},\n",
    "    scheduler_fn=torch.optim.lr_scheduler.StepLR,\n",
    "    scheduler_params=dict(step_size=50, gamma=0.9),\n",
    "    mask_type='entmax',  # alternative: \"sparsemax\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NB_TASKS = 5 # Just a toy example to mimic multitask multiclassification problem\n",
    "\n",
    "\n",
    "def build_multitask_split(indices):\n",
    "    \"\"\"Return (X, y) for the rows in `indices`.\n",
    "\n",
    "    X holds the feature columns; y repeats the target column NB_TASKS times.\n",
    "    The last task column is overwritten with random labels drawn from 10..14\n",
    "    to show how a task with a different label set is handled.\n",
    "    \"\"\"\n",
    "    X = train[features].values[indices]\n",
    "    y = train[target].values[indices].reshape(-1, 1)\n",
    "    y = np.hstack([y] * NB_TASKS)\n",
    "    y[:, -1] = np.random.randint(10, 15, y.shape[0]).astype(str)\n",
    "    return X, y\n",
    "\n",
    "\n",
    "# Call order matters for reproducibility: every call draws from the global NumPy RNG.\n",
    "X_train, y_train = build_multitask_split(train_indices)\n",
    "X_valid, y_valid = build_multitask_split(valid_indices)\n",
    "X_test, y_test = build_multitask_split(test_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the run short (2 epochs) whenever the CI environment variable is set to a non-empty value.\n",
    "max_epochs = 2 if os.getenv(\"CI\", False) else 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# One loss function per task; passing a list is optional and only illustrates list usage.\n",
    "task_losses = [torch.nn.functional.cross_entropy for _ in range(NB_TASKS)]\n",
    "\n",
    "clf.fit(\n",
    "    X_train=X_train,\n",
    "    y_train=y_train,\n",
    "    eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
    "    eval_name=['train', 'valid'],\n",
    "    max_epochs=max_epochs,\n",
    "    patience=20,\n",
    "    batch_size=1024,\n",
    "    virtual_batch_size=128,\n",
    "    num_workers=0,\n",
    "    drop_last=False,\n",
    "    loss_fn=task_losses,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot the training loss history\n",
    "plt.plot(clf.history['loss'])\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('loss')\n",
    "plt.title('Training loss');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot logloss for both evaluation sets; the legend tells the two curves apart\n",
    "plt.plot(clf.history['train_logloss'], label='train')\n",
    "plt.plot(clf.history['valid_logloss'], label='valid')\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('logloss')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot the learning-rate history\n",
    "plt.plot(clf.history['lr'])\n",
    "plt.xlabel('epoch')\n",
    "plt.ylabel('learning rate');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = clf.predict_proba(X_test)\n",
    "\n",
    "# One AUC per task, scored on column 1 of that task's probabilities;\n",
    "# the last task holds random labels and is left out.\n",
    "test_aucs = []\n",
    "for task_idx, task_pred in enumerate(preds[:-1]):\n",
    "    task_auc = roc_auc_score(y_true=y_test[:, task_idx], y_score=task_pred[:, 1])\n",
    "    test_aucs.append(task_auc)\n",
    "\n",
    "print(f\"BEST VALID SCORE FOR {dataset_name} : {clf.best_cost}\")\n",
    "print(f\"FINAL AUC SCORES FOR {dataset_name} : {test_aucs}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predict_classes = clf.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## accessing class mapping\n",
    "clf.classes_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predict_classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf.target_mapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average column 1 of the real tasks only. The last task was given random labels from a\n",
    "# different label set, so its column 1 is not a score for the target in y_test[:, 0].\n",
    "real_task_scores = np.vstack([task_pred[:, 1] for task_pred in preds[:-1]])\n",
    "ensemble_auc = roc_auc_score(y_true=y_test[:, 0],\n",
    "                             y_score=real_task_scores.mean(axis=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ensemble_auc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save and load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save tabnet model\n",
    "saving_path_name = \"./MultiTaskClassifier_1\"\n",
    "saved_filepath = clf.save_model(saving_path_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define new model with basic parameters and load state dict weights\n",
    "loaded_clf = TabNetMultiTaskClassifier()\n",
    "loaded_clf.load_model(saved_filepath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_preds = loaded_clf.predict_proba(X_test)\n",
    "\n",
    "# Per-task AUC of the reloaded model, skipping the last (random-label) task.\n",
    "loaded_test_auc = []\n",
    "for task_idx, task_pred in enumerate(loaded_preds[:-1]):\n",
    "    loaded_test_auc.append(roc_auc_score(y_true=y_test[:, task_idx], y_score=task_pred[:, 1]))\n",
    "\n",
    "print(f\"FINAL AUCS SCORE FOR {dataset_name} : {loaded_test_auc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The reloaded model must reproduce the test AUCs of the model it was saved from.\n",
    "assert test_aucs == loaded_test_auc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_clf.predict(X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Global explainability: feature importances summing to 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf.feature_importances_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Local explainability and masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "explain_matrix, masks = clf.explain(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the first 50 rows of masks 0, 1 and 2 side by side.\n",
    "fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(20, 20))\n",
    "\n",
    "for step, ax in enumerate(axs):\n",
    "    ax.imshow(masks[step][:50])\n",
    "    ax.set_title(f\"mask {step}\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
